#-*-Python-*-

# Clipping settings, referenced below as %name.
# - clipping_coeff: bound to ppo_policy_loss and ppo2_value_loss.
# - mcts_clipping_coeff: bound to ppo_policy_loss.
# - tanh_action_clipping: bound to ParallelEnv and ppo_policy_loss.
clipping_coeff = 0.2
mcts_clipping_coeff = 0.2
tanh_action_clipping = False

# Loop length and environment dimensions, referenced below as %name.
# - iterations_per_loop: bound to get_tpu_estimator, create_estimator,
#   PpoModelFn, PpoInputFn and PpoTrainer.
# - env_action_space: bound to PpoModelFn, PpoInputFn, MCTSPlayer, MCTSNode.
# - env_state_space: bound to serving_input_fn and PpoInputFn.
# - max_horizon: bound to PpoModelFn and PpoInputFn.
iterations_per_loop = 20
env_action_space = 2
env_state_space = 2
max_horizon = 2048

# Output locations. checkpoint_dir and summary_dir point at the same
# directory, and train_state_file is a file inside that directory.
# NOTE(review): everything lives under /tmp -- confirm this is intended for
# runs whose checkpoints and summaries must outlive the machine.
checkpoint_dir = '/tmp/tp-d/home/ml-for-systems'
summary_dir = '/tmp/tp-d/home/ml-for-systems'
train_state_file = '/tmp/tp-d/home/ml-for-systems/train_state.p'
# Bound to PpoModelFn.warmstart_file; None in this configuration.
warmstart_file = None

# Bound to get_tpu_estimator, PpoModelFn, PpoInputFn and PpoTrainer.
use_tpu = True

# Bound to ParallelEnv.len_segment.
len_segment = 32

# Training schedule.
# - num_timesteps (4194304 == 2**22) and learning_rate: bound to PpoModelFn.
# - train_batch_size: bound to get_tpu_estimator and PpoInputFn.
# - num_iterations: bound to PpoTrainer only; ParallelEnv.num_iterations is
#   wired to %mcts_num_iterations instead.
num_timesteps = 4194304
learning_rate = 3e-4
train_batch_size = 1024
num_iterations = 20480

# Bound to PpoModelFn.tpu_num_shards.
tpu_num_shards = 8

# Bound to PpoModelFn.ppo2_enable.
ppo2_enable = True

# Loss term coefficients. All five are bound to PpoModelFn below.
policy_coeff = 1.0
value_coeff = 0.5
entropy_coeff = 0.0
mse_loss_coeff = 0.5
mcts_policy_coeff = 1.0

# Normalization and clipping settings for observations and rewards, plus
# gamma and lam. Every macro in this group is bound to ParallelEnv below.
# NOTE(review): gamma/lam are presumably the discount factor and the GAE
# lambda -- confirm against env.py.
obs_normalized = True
reward_normalized = True
clip_ob = 10.0
clip_rew = 10.0
epsilon = 1e-8
gamma = 0.99
lam = 0.95

# Bound to get_tpu_estimator, create_estimator and PpoTrainer.
keep_checkpoint_max = 1000

# MCTS hyperparameters, grouped by the configurable that consumes them
# (see the bindings further down in this file).

# Switches.
# mcts_enable is bound to ParallelEnv and PpoInputFn.
mcts_enable = True
# Bound to MCTSPlayer.
# NOTE(review): this flag was annotated "Not enabled" although it is set to
# True -- confirm which of the two is intended.
build_dist_enable = True
# Bound to MCTSPlayer. If True, using max(Q) for estimation.
max_q_enable = False

# Search sizing, bound to MCTSPlayer (num_envs also to ParallelEnv;
# max_num_actions and temperature also to MCTSNode).
num_envs = 32
max_num_actions = 32
num_mcts_sim = 8
num_moves_per_search = 1
temperature = 80000

# Node-level settings, bound to MCTSNode.
dirichlet_noise_alpha = 1.75
dirichlet_noise_weight = 0.5
mcts_value_discount = 0.99
c_puct = 2.0

# Data-collection schedule, bound to ParallelEnv.
# NOTE(review): mcts_end_step_frac (1.2) and mcts_sampling_frac (-0.5) fall
# outside [0, 1] despite the "_frac" names -- confirm these are deliberate.
mcts_start_step_frac = 0.1
mcts_end_step_frac = 1.2
mcts_sim_decay_factor = 1.0
mcts_sampling_frac = -0.5
# Wired to ParallelEnv.num_iterations (not to PpoTrainer.num_iterations).
mcts_num_iterations = 512
mcts_collect_data_freq = 80

# Network sizes and remaining switches, grouped by consumer.

# Bound to PpoModelFn.
# NOTE(review): ppo_mcts is False while mcts_enable is True -- confirm the
# combination is intended.
policy_hidden_layer_size = 64
value_hidden_layer_size = 64
ppo_mcts = False

# Bound to ppo_policy_loss.
policy_gradient_enable = False

# Bound to ParallelEnv.
random_action_sampling_freq = 0.1

# Bound to PpoInputFn.
return_estimate = 0

# Environment selection for ppo_train. These two values are set directly
# rather than through macros.
# NOTE(review): confirm env_state_space / env_action_space (both 2 above)
# match the observation and action dimensions of this environment.
ppo_train.env_name = "Reacher-v2"
ppo_train.env_seed = 7

# Parameters for env.py
# ==============================================================================
# Every ParallelEnv parameter is wired to the macro of the same name, with
# one exception: num_iterations reads %mcts_num_iterations (512), not
# %num_iterations (which goes to PpoTrainer).
ParallelEnv.gamma = %gamma
ParallelEnv.lam = %lam
ParallelEnv.tanh_action_clipping = %tanh_action_clipping
ParallelEnv.obs_normalized = %obs_normalized
ParallelEnv.reward_normalized = %reward_normalized
ParallelEnv.clip_ob = %clip_ob
ParallelEnv.clip_rew = %clip_rew
ParallelEnv.epsilon = %epsilon
ParallelEnv.mcts_enable = %mcts_enable
ParallelEnv.num_envs = %num_envs
ParallelEnv.mcts_start_step_frac = %mcts_start_step_frac
ParallelEnv.mcts_end_step_frac = %mcts_end_step_frac
# Intentionally-looking name mismatch: uses the MCTS iteration count.
# NOTE(review): confirm this should not be %num_iterations.
ParallelEnv.num_iterations = %mcts_num_iterations
ParallelEnv.mcts_sim_decay_factor = %mcts_sim_decay_factor
ParallelEnv.mcts_sampling_frac = %mcts_sampling_frac
ParallelEnv.random_action_sampling_freq = %random_action_sampling_freq
ParallelEnv.mcts_collect_data_freq = %mcts_collect_data_freq
ParallelEnv.checkpoint_dir = %checkpoint_dir
ParallelEnv.len_segment = %len_segment

# Parameters for host_call_fn.py
# ==============================================================================
# Both host-call builders receive the same summary directory.
build_host_call_fn_every_n_global_steps.summary_dir = %summary_dir
build_host_call_fn.summary_dir = %summary_dir

# Parameters for tf_utils.py
# ==============================================================================
# get_tpu_estimator and create_estimator share %iterations_per_loop and
# %keep_checkpoint_max; only get_tpu_estimator additionally receives
# use_tpu and train_batch_size.
get_tpu_estimator.iterations_per_loop = %iterations_per_loop
get_tpu_estimator.keep_checkpoint_max = %keep_checkpoint_max
get_tpu_estimator.use_tpu = %use_tpu
get_tpu_estimator.train_batch_size = %train_batch_size

create_estimator.keep_checkpoint_max = %keep_checkpoint_max
create_estimator.iterations_per_loop = %iterations_per_loop

# Same macro that PpoInputFn.env_state_space uses.
serving_input_fn.env_state_space = %env_state_space

# Parameters for ppo_loss.py
# ==============================================================================
# Policy Loss
ppo_policy_loss.clipping_coeff = %clipping_coeff
ppo_policy_loss.mcts_clipping_coeff = %mcts_clipping_coeff
ppo_policy_loss.tanh_action_clipping = %tanh_action_clipping
ppo_policy_loss.policy_gradient_enable = %policy_gradient_enable
# Value Loss
# Shares %clipping_coeff with the policy loss above.
ppo2_value_loss.clipping_coeff = %clipping_coeff

# Parameters for ppo_model_fn.py
# ==============================================================================
# Every PpoModelFn parameter is wired to the macro of the same name.
PpoModelFn.env_action_space = %env_action_space
PpoModelFn.iterations_per_loop = %iterations_per_loop
PpoModelFn.num_timesteps = %num_timesteps
PpoModelFn.max_horizon = %max_horizon
PpoModelFn.learning_rate = %learning_rate
PpoModelFn.use_tpu = %use_tpu
PpoModelFn.ppo2_enable = %ppo2_enable
# Loss coefficients
PpoModelFn.policy_coeff = %policy_coeff
PpoModelFn.value_coeff = %value_coeff
PpoModelFn.mse_loss_coeff = %mse_loss_coeff
PpoModelFn.entropy_coeff = %entropy_coeff
PpoModelFn.mcts_policy_coeff = %mcts_policy_coeff
PpoModelFn.tpu_num_shards = %tpu_num_shards
# %warmstart_file is None in this configuration.
PpoModelFn.warmstart_file = %warmstart_file

# Hidden Layer Size
PpoModelFn.policy_hidden_layer_size = %policy_hidden_layer_size
PpoModelFn.value_hidden_layer_size = %value_hidden_layer_size

# %ppo_mcts is False in this configuration.
PpoModelFn.ppo_mcts = %ppo_mcts

# Parameters for ppo_input_fn.py
# ==============================================================================
# Every PpoInputFn parameter is wired to the macro of the same name.
PpoInputFn.train_batch_size = %train_batch_size
PpoInputFn.iterations_per_loop = %iterations_per_loop
PpoInputFn.max_horizon = %max_horizon
PpoInputFn.env_state_space = %env_state_space
PpoInputFn.env_action_space = %env_action_space
PpoInputFn.use_tpu = %use_tpu
PpoInputFn.mcts_enable = %mcts_enable
# checkpoint_dir and summary_dir resolve to the same directory here, and
# train_state_file is a file inside it.
PpoInputFn.checkpoint_dir = %checkpoint_dir
PpoInputFn.summary_dir = %summary_dir
PpoInputFn.train_state_file = %train_state_file
PpoInputFn.return_estimate = %return_estimate

# Parameters for ppo_trainer.py
# ==============================================================================
# PpoTrainer is the only consumer of %num_iterations (20480); the remaining
# parameters reuse macros shared with the estimator and input/model functions.
PpoTrainer.num_iterations = %num_iterations
PpoTrainer.iterations_per_loop = %iterations_per_loop
PpoTrainer.checkpoint_dir = %checkpoint_dir
PpoTrainer.keep_checkpoint_max = %keep_checkpoint_max
PpoTrainer.use_tpu = %use_tpu


# Parameters for mcts_player.py
# ==============================================================================
# All parameters come from macros except debug, which is set literally.
# env_action_space, max_num_actions and temperature are the same macros
# that MCTSNode receives.
MCTSPlayer.env_action_space = %env_action_space
MCTSPlayer.max_num_actions = %max_num_actions
MCTSPlayer.num_mcts_sim = %num_mcts_sim
MCTSPlayer.temperature = %temperature
MCTSPlayer.num_envs = %num_envs
MCTSPlayer.debug = False
# %build_dist_enable is True in this configuration.
MCTSPlayer.build_dist_enable = %build_dist_enable
MCTSPlayer.max_q_enable = %max_q_enable
MCTSPlayer.num_moves_per_search = %num_moves_per_search

# Parameters for mcts_node.py
# ==============================================================================
# All parameters come from macros except debug, which is set literally.
# temperature, max_num_actions and env_action_space are the same macros
# that MCTSPlayer receives.
MCTSNode.dirichlet_noise_alpha = %dirichlet_noise_alpha
MCTSNode.dirichlet_noise_weight = %dirichlet_noise_weight
MCTSNode.mcts_value_discount = %mcts_value_discount
MCTSNode.c_puct = %c_puct
MCTSNode.temperature = %temperature
MCTSNode.max_num_actions = %max_num_actions
MCTSNode.env_action_space = %env_action_space
MCTSNode.debug = False
